
[Torch] Remove unnecessary reshapes for batch_matmul #7675

Merged Mar 17, 2021 (5 commits)

Conversation

comaniac (Contributor)
This PR removes unnecessary reshape ops in the PyTorch frontend when converting matmul-style ops to batch_matmul, which should improve the performance of NLP models such as BERT.

cc @siju-samuel @masahi
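For context, Relay's nn.batch_matmul expects its second operand pre-transposed, which is why the converter emits a transpose before the call. A minimal NumPy sketch of that contract (the helper name is illustrative, not TVM API):

```python
import numpy as np

def batch_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Relay's nn.batch_matmul computes out[i] = a[i] @ b[i].T,
    # i.e. b arrives with its last two axes already swapped.
    return np.matmul(a, np.swapaxes(b, -1, -2))

a = np.random.rand(10, 3, 4).astype("float32")
b = np.random.rand(10, 5, 4).astype("float32")  # (batch, N, K), pre-transposed
out = batch_matmul(a, b)  # shape (10, 3, 5)
```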

comaniac (Contributor, Author)
Pushed a new commit that also reorders the reshape_b and transpose so that the SimplifyExpr pass can be applied.

Before this PR:

fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(10, 4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
  %0 = reshape(%input0, newshape=[-1, 3, 4]) /* ty=Tensor[(10, 3, 4), float32] */;
  %1 = reshape(%input1, newshape=[-1, 4, 5]) /* ty=Tensor[(10, 4, 5), float32] */;
  %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(10, 5, 4), float32] */;
  %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */;
  reshape(%3, newshape=[10, 3, 5]) /* ty=Tensor[(10, 3, 5), float32] */
}

fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
  %0 = reshape(%input0, newshape=[-1, 3, 4]) /* ty=Tensor[(10, 3, 4), float32] */;
  %1 = reshape(%input1, newshape=[-1, 4, 5]) /* ty=Tensor[(1, 4, 5), float32] */;
  %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(1, 5, 4), float32] */;
  %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */;
  reshape(%3, newshape=[10, 3, 5]) /* ty=Tensor[(10, 3, 5), float32] */
}

fn (%input0: Tensor[(1, 12, 14, 64), float32], %input1: Tensor[(1, 12, 64, 14), float32]) -> Tensor[(1, 12, 14, 14), float32] {
  %0 = reshape(%input0, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), float32] */;
  %1 = reshape(%input1, newshape=[-1, 64, 14]) /* ty=Tensor[(12, 64, 14), float32] */;
  %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(12, 14, 64), float32] */;
  %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(12, 14, 14), float32] */;
  reshape(%3, newshape=[1, 12, 14, 14]) /* ty=Tensor[(1, 12, 14, 14), float32] */
}

After this PR:

fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(10, 4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
  %0 = transpose(%input1, axes=[0, 2, 1]) /* ty=Tensor[(10, 5, 4), float32] */;
  nn.batch_matmul(%input0, %0, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
}

fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(4, 5), float32]) -> Tensor[(10, 3, 5), float32] {
  %0 = transpose(%input1, axes=[1, 0]) /* ty=Tensor[(5, 4), float32] */;
  %1 = reshape(%0, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
  nn.batch_matmul(%input0, %1, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
}

fn (%input0: Tensor[(1, 12, 14, 64), float32], %input1: Tensor[(1, 12, 64, 14), float32]) -> Tensor[(1, 12, 14, 14), float32] {
  %0 = reshape(%input0, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), float32] */;
  %1 = transpose(%input1, axes=[0, 1, 3, 2]) /* ty=Tensor[(1, 12, 14, 64), float32] */;
  %2 = reshape(%1, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), float32] */;
  %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(12, 14, 14), float32] */;
  reshape(%3, newshape=[1, 12, 14, 14]) /* ty=Tensor[(1, 12, 14, 14), float32] */
}
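The reorder in the third case is semantics-preserving: collapsing the leading batch dimensions and transposing the last two axes touch disjoint axes, so the two operations commute. A quick NumPy check of that claim, using the shapes from the example:

```python
import numpy as np

x = np.random.rand(1, 12, 64, 14).astype("float32")

# Before the reorder: collapse the batch dims, then transpose the last two axes.
before = x.reshape(-1, 64, 14).transpose(0, 2, 1)   # (12, 14, 64)

# After the reorder: transpose the last two axes, then collapse the batch dims.
after = x.transpose(0, 1, 3, 2).reshape(-1, 14, 64)  # (12, 14, 64)

same = np.array_equal(before, after)  # the two orderings agree elementwise
```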

In particular, since the weights in most PyTorch models have to be transposed when converting to Relay, the second case, for example, could be:

fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(5, 4), float32]) -> Tensor[(10, 3, 5), float32] {
  %0 = transpose(%input1, axes=[1, 0]) /* ty=Tensor[(4, 5), float32] */; <- Not added by matmul
  %1 = transpose(%0, axes=[1, 0]) /* ty=Tensor[(5, 4), float32] */; <- Added by matmul
  %2 = reshape(%1, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
  nn.batch_matmul(%input0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
}

By applying SimplifyExpr to cancel the unnecessary transposes, we end up with:

fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(5, 4), float32]) -> Tensor[(10, 3, 5), float32] {
  %0 = reshape(%input1, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
  nn.batch_matmul(%input0, %0, meta[relay.attrs.BatchMatmulAttrs][0]) /* ty=Tensor[(10, 3, 5), float32] */
}
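The cancellation is safe because the two transposes are inverses of each other, so only the reshape that adds the unit batch dimension needs to remain. A small NumPy sketch:

```python
import numpy as np

w = np.random.rand(5, 4).astype("float32")

# Two successive transposes over the same axes are the identity,
# which is what lets SimplifyExpr remove the pair.
round_trip = w.transpose(1, 0).transpose(1, 0)

# What survives is just the reshape adding the unit batch dim.
batched = w.reshape(-1, 5, 4)  # shape (1, 5, 4)
```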

@masahi masahi merged commit 4abbe49 into apache:main Mar 17, 2021
masahi (Member) commented Mar 17, 2021

Thanks @comaniac

@comaniac comaniac deleted the pytorch_remove_reshape branch March 17, 2021 16:30
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
* [Torch] Remove unnecessary reshapes for batch_matmul

* lint

* fix

* reorder

* lint
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request May 11, 2021